import halide as hl


def test_tuple_select():
    x = hl.Var("x")
    y = hl.Var("y")

    # ternary select with Expr condition
    f = hl.Func("f")
    f[x, y] = hl.select(x + y < 30, (x, y), (x - 1, y - 2))

    a, b = f.realize([200, 200])
    for xx in range(a.height()):
        for yy in range(a.width()):
            correct_a = xx if xx + yy < 30 else xx - 1
            correct_b = yy if xx + yy < 30 else yy - 2
            assert a[xx, yy] == correct_a
            assert b[xx, yy] == correct_b

    # ternary select with Tuple condition
    f = hl.Func("f")
    f[x, y] = hl.select((x < 30, y < 30), (x, y), (x - 1, y - 2))

    a, b = f.realize([200, 200])
    for xx in range(a.height()):
        for yy in range(a.width()):
            correct_a = xx if xx < 30 else xx - 1
            correct_b = yy if yy < 30 else yy - 2
            assert a[xx, yy] == correct_a
            assert b[xx, yy] == correct_b

    # multiway select with Expr condition
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select(x + y < 30,  (x, y),
                        x + y < 100, (x-1, y-2),
                                     (x-100, y-200))
    # fmt: on

    a, b = f.realize([200, 200])
    for xx in range(a.height()):
        for yy in range(a.width()):
            correct_a = xx if xx + yy < 30 else xx - 1 if xx + yy < 100 else xx - 100
            correct_b = yy if xx + yy < 30 else yy - 2 if xx + yy < 100 else yy - 200
            assert a[xx, yy] == correct_a
            assert b[xx, yy] == correct_b

    # multiway select with Tuple condition
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select((x < 30, y < 30),   (x, y),
                        (x < 100, y < 100), (x-1, y-2),
                                            (x-100, y-200))
    # fmt: on

    a, b = f.realize([200, 200])
    for xx in range(a.height()):
        for yy in range(a.width()):
            correct_a = xx if xx < 30 else xx - 1 if xx < 100 else xx - 100
            correct_b = yy if yy < 30 else yy - 2 if yy < 100 else yy - 200
            assert a[xx, yy] == correct_a
            assert b[xx, yy] == correct_b

    # Failure case: mixing Expr and Tuple in multiway
    try:
        f = hl.Func("f")
        # fmt: off
        f[x, y] = hl.select((x < 30, y < 30), (x, y),
                             x + y < 100,     (x-1, y-2),
                                              (x-100, y-200))
        # fmt: on
    except hl.HalideError as e:
        assert "select(): Expected Tuple but got Expr: ((x + y) < 100)" in str(e)
    else:
        assert False, "Did not see expected exception!"

    # Failure case: Tuples of mixed sizes
    try:
        f = hl.Func("f")
        f[x, y] = hl.select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4))
    except hl.HalideError as e:
        assert "select() on Tuples requires all Tuples to have identical sizes." in str(
            e
        )
    else:
        assert False, "Did not see expected exception!"


if __name__ == "__main__":
    test_tuple_select()
